import faulthandler;

faulthandler.enable()

import copy
import gym
import random
import numpy as np
import statistics
import itertools

# Import environment
import improved_humanoid
from SnapshotENV import SnapshotEnv


class TunableCEM:
    """Cross-Entropy Method (CEM) planner with tunable hyperparameters.

    Keeps a diagonal Gaussian (``mean``/``std``, both float32 arrays of shape
    ``(horizon, action_dim)``) over open-loop action sequences. Each planning
    call samples a population, scores every sequence by rolling it out from an
    environment snapshot, and refits the Gaussian to the elite sequences.

    Keyword hyperparameters (all optional):
        horizon, pop_size, elite_frac, cem_iters, init_std, std_decay,
        min_std, bounds -- core CEM settings.
        gamma -- per-step discount used when scoring rollouts (default 0.99).
        use_momentum, momentum_alpha -- refit mean becomes
            ``(1 - alpha) * new_mean + alpha * current_mean``.
        use_antithetic -- the second half of the population mirrors the
            first half's noise (``mean - eps`` for each ``mean + eps``).
        regularization -- adds ``regularization * init_std`` to the refit std.

    Raises:
        ValueError: if ``horizon`` or ``pop_size`` is < 1.
    """

    def __init__(self, action_dim, **kwargs):
        self.action_dim = action_dim

        # Hyperparameters with defaults
        self.horizon = kwargs.get('horizon', 8)
        self.pop_size = kwargs.get('pop_size', 150)
        if self.horizon < 1 or self.pop_size < 1:
            raise ValueError("horizon and pop_size must both be >= 1")
        self.elite_frac = kwargs.get('elite_frac', 0.1)
        # At least one elite, and never more elites than samples.
        self.n_elite = min(self.pop_size, max(1, int(self.pop_size * self.elite_frac)))
        self.cem_iters = kwargs.get('cem_iters', 10)
        self.init_std = kwargs.get('init_std', 0.6)
        self.std_decay = kwargs.get('std_decay', 0.9)
        self.min_std = kwargs.get('min_std', 0.15)
        self.bounds = kwargs.get('bounds', (-1.0, 1.0))
        self.gamma = kwargs.get('gamma', 0.99)

        # Advanced options
        self.use_momentum = kwargs.get('use_momentum', False)
        self.momentum_alpha = kwargs.get('momentum_alpha', 0.1)
        self.use_antithetic = kwargs.get('use_antithetic', False)
        self.regularization = kwargs.get('regularization', 0.0)

        self.reset_distribution()

    def reset_distribution(self):
        """Reset to a zero-mean Gaussian with ``init_std`` in every entry."""
        shape = (self.horizon, self.action_dim)
        self.mean = np.zeros(shape, dtype=np.float32)
        self.std = np.full(shape, self.init_std, dtype=np.float32)
        if self.use_momentum:
            # Record of the most recent blended mean (kept for inspection).
            self.prev_mean = np.zeros_like(self.mean)

    def sample_sequences(self):
        """Sample a population of ``pop_size`` action sequences.

        Returns:
            list of sequences; each sequence is a list of ``horizon`` float32
            arrays of shape ``(action_dim,)``, clipped to ``bounds``.

        With ``use_antithetic`` the last ``pop_size // 2`` sequences reuse the
        negated noise of the first ones, so the pairs are truly antithetic.
        For an odd ``pop_size`` the extra sample is an ordinary unpaired draw.
        """
        n_mirror = self.pop_size // 2 if self.use_antithetic else 0
        n_regular = self.pop_size - n_mirror

        noise = np.random.standard_normal((n_regular, self.horizon, self.action_dim)) * self.std
        if n_mirror:
            # Mirror the *same* noise; independent draws would not be antithetic.
            noise = np.concatenate([noise, -noise[:n_mirror]], axis=0)

        actions = np.clip(self.mean + noise, self.bounds[0], self.bounds[1]).astype(np.float32)
        return [list(seq) for seq in actions]

    def update_distribution(self, sequences, scores):
        """Refit the Gaussian to the ``n_elite`` highest-scoring sequences.

        Args:
            sequences: population as returned by ``sample_sequences``.
            scores: one scalar score per sequence (higher is better).
        """
        elite_idx = np.argsort(scores)[-self.n_elite:]
        # Shape (n_elite, horizon, action_dim).
        elite = np.asarray([sequences[i] for i in elite_idx], dtype=np.float32)

        new_mean = elite.mean(axis=0)
        # ddof=1 is undefined for a single elite (NaN, which np.maximum below would
        # propagate); fall back to ddof=0 so std becomes 0 and is floored at min_std.
        ddof = 1 if elite.shape[0] > 1 else 0
        new_std = elite.std(axis=0, ddof=ddof)

        # Apply regularization
        if self.regularization > 0:
            new_std = new_std + self.regularization * self.init_std

        if self.use_momentum:
            # Blend with the current mean, which shift_horizon()/reset_distribution()
            # keep aligned with the horizon (a separately stored copy would not be).
            alpha = self.momentum_alpha
            self.mean = ((1 - alpha) * new_mean + alpha * self.mean).astype(np.float32)
            self.prev_mean = self.mean.copy()
        else:
            self.mean = new_mean.astype(np.float32)

        # Update std with decay and minimum
        self.std = np.maximum(self.std_decay * new_std, self.min_std).astype(np.float32)

    def plan_action(self, env, snapshot):
        """Run ``cem_iters`` CEM iterations from ``snapshot``; return the first action.

        ``env`` must provide ``load_snapshot(snapshot)`` and a ``step(action)``
        returning ``(obs, reward, done, info)``.

        Returns:
            a copy of ``mean[0]`` (float32, shape ``(action_dim,)``).
        """
        for _ in range(self.cem_iters):
            sequences = self.sample_sequences()
            scores = [self._rollout_return(env, snapshot, seq) for seq in sequences]
            self.update_distribution(sequences, scores)

        return self.mean[0].copy()

    def _rollout_return(self, env, snapshot, seq):
        """Return the ``gamma``-discounted reward of ``seq`` run open-loop from ``snapshot``."""
        env.load_snapshot(snapshot)
        total_reward = 0.0
        discount = 1.0
        for action in seq:
            _, r, done, _ = env.step(action)
            total_reward += r * discount
            discount *= self.gamma
            if done:
                break
        return total_reward

    def shift_horizon(self):
        """Advance the plan one step: drop ``mean[0]`` and append a zero mean.

        ``std`` is deliberately left unchanged, for stability.
        """
        self.mean[:-1] = self.mean[1:]
        self.mean[-1] = 0.0
        if self.use_momentum:
            self.prev_mean = self.mean.copy()


def evaluate_cem_config(config, num_seeds=5, max_steps=100, envname="ImprovedHumanoid-v0",
                        stoch_kwargs=None, gamma=0.99):
    """Evaluate a CEM configuration by running one planned episode per seed.

    For each seed in ``range(num_seeds)`` the Python and NumPy global RNGs are
    seeded, fresh planning and test environments are created, and a
    ``TunableCEM`` built from ``config`` chooses every action of the test
    episode (at most ``max_steps`` steps) by planning from a snapshot of the
    test environment.

    Args:
        config: keyword arguments forwarded to ``TunableCEM``.
        num_seeds: number of seeded episodes; must be >= 1.
        max_steps: maximum environment steps per episode.
        envname: gym environment id passed to ``gym.make``.
        stoch_kwargs: extra kwargs for ``gym.make``; ``None`` selects the
            default noise scales below.
        gamma: discount applied to the reported episode return.

    Returns:
        ``(mean, population std)`` of the per-seed discounted returns.

    Raises:
        ValueError: if ``num_seeds < 1``.
    """
    if num_seeds < 1:
        raise ValueError("num_seeds must be >= 1")
    if stoch_kwargs is None:
        stoch_kwargs = {
            "action_noise_scale": 0.03,
            "dynamics_noise_scale": 0.02,
            "obs_noise_scale": 0.01
        }

    seed_returns = []

    for seed in range(num_seeds):
        random.seed(seed)
        np.random.seed(seed)

        planning_env = test_env = None
        try:
            # Create environments
            planning_env = SnapshotEnv(gym.make(envname, **stoch_kwargs).env)
            raw_test_env = gym.make(envname, **stoch_kwargs).env
            test_env = SnapshotEnv(raw_test_env)
            test_env.reset()

            # Size the planner from the env rather than hard-coding Humanoid's 17.
            cem = TunableCEM(action_dim=raw_test_env.action_space.shape[0], **config)

            total_reward = 0.0
            discount = 1.0
            for _ in range(max_steps):
                snapshot = test_env.get_snapshot()
                action = cem.plan_action(planning_env, snapshot)

                _, r, done, _ = test_env.step(action)
                total_reward += r * discount
                discount *= gamma

                if done:
                    break

                cem.shift_horizon()
        finally:
            # Release both environments even if planning or stepping raises.
            for env in (planning_env, test_env):
                if env is not None:
                    env.close()

        seed_returns.append(total_reward)

    return statistics.mean(seed_returns), statistics.pstdev(seed_returns)


def hyperparameter_search(search_seeds=3, final_seeds=10):
    """Greedy one-parameter-at-a-time search over CEM hyperparameters.

    Parameters are tuned in ``search_space`` order. Any value whose mean return
    beats the best so far is kept. The baseline and every candidate are scored
    on the same ``search_seeds`` seeds, so their means are directly comparable.
    The winning configuration is then re-evaluated on ``final_seeds`` seeds.

    Args:
        search_seeds: seeds used for the baseline and for each candidate.
        final_seeds: seeds used for the final evaluation of the best config.

    Returns:
        dict: the best configuration found.
    """

    print("CEM Hyperparameter Search for ImprovedHumanoid-v0")
    print("=" * 60)

    # Define search space
    search_space = {
        'horizon': [5, 8, 12],
        'pop_size': [100, 150, 200],
        'elite_frac': [0.1, 0.15, 0.2],
        'cem_iters': [8, 10, 12],
        'init_std': [0.4, 0.6, 0.8],
        'std_decay': [0.85, 0.9, 0.95],
        'min_std': [0.1, 0.15, 0.2]
    }

    # Baseline configuration
    baseline_config = {
        'horizon': 8,
        'pop_size': 150,
        'elite_frac': 0.1,
        'cem_iters': 10,
        'init_std': 0.6,
        'std_decay': 0.9,
        'min_std': 0.15
    }

    print("Testing baseline configuration...")
    # Same seed set as the candidates; a different seed count would bias the comparison.
    baseline_mean, baseline_std = evaluate_cem_config(baseline_config, num_seeds=search_seeds)
    print(f"Baseline: {baseline_mean:.2f} ± {baseline_std:.2f}")

    best_config = baseline_config.copy()
    best_score = baseline_mean

    # Grid search (one parameter at a time to keep it manageable)
    for param_name, param_values in search_space.items():
        print(f"\nTuning {param_name}: {param_values}")

        for value in param_values:
            if value == baseline_config[param_name]:
                # This value (with the current best settings for the other
                # parameters) is what best_score already measures.
                continue

            config = best_config.copy()
            config[param_name] = value

            mean_score, std_score = evaluate_cem_config(config, num_seeds=search_seeds)
            print(f"  {param_name}={value}: {mean_score:.2f} ± {std_score:.2f}")

            if mean_score > best_score:
                best_score = mean_score
                best_config[param_name] = value
                print(f"    New best! Score improved to {best_score:.2f}")

    print("\n" + "=" * 60)
    print("HYPERPARAMETER SEARCH RESULTS")
    print("=" * 60)

    print("Best configuration found:")
    for param, value in best_config.items():
        print(f"  {param}: {value}")

    print(f"\nBest score: {best_score:.2f}")
    print(f"Improvement over baseline: {best_score - baseline_mean:.2f}")

    # Final evaluation with more seeds
    print("\nFinal evaluation with more seeds...")
    final_mean, final_std = evaluate_cem_config(best_config, num_seeds=final_seeds)
    print(f"Final score: {final_mean:.2f} ± {final_std:.2f}")

    return best_config


def test_cem_variants():
    """Compare standard CEM against its momentum/antithetic/regularized variants.

    Each variant's overrides are layered on a shared base configuration and
    scored with ``evaluate_cem_config`` on 5 seeds. The results are then
    printed ranked by mean return.

    Returns:
        list of ``(name, mean, std)`` tuples, best mean first.
    """
    print("Testing CEM Variants")
    print("=" * 40)

    base_config = dict(
        horizon=8,
        pop_size=150,
        elite_frac=0.1,
        cem_iters=10,
        init_std=0.6,
        std_decay=0.9,
        min_std=0.15,
    )

    variants = (
        ("Standard CEM", {}),
        ("With Momentum", {'use_momentum': True, 'momentum_alpha': 0.1}),
        ("With Antithetic", {'use_antithetic': True}),
        ("With Regularization", {'regularization': 0.05}),
        ("Combined", {'use_momentum': True, 'use_antithetic': True, 'regularization': 0.02}),
    )

    scored = []
    for label, overrides in variants:
        print(f"Testing {label}...")
        avg, spread = evaluate_cem_config({**base_config, **overrides}, num_seeds=5)
        scored.append((label, avg, spread))
        print(f"  {label}: {avg:.2f} ± {spread:.2f}")

    print("\nVariant Comparison:")
    print("-" * 40)
    ranked = sorted(scored, key=lambda row: row[1], reverse=True)
    for rank, (label, avg, spread) in enumerate(ranked, start=1):
        print(f"{rank}. {label}: {avg:.2f} ± {spread:.2f}")

    return ranked


if __name__ == "__main__":
    # Interactive entry point: show the menu, then dispatch on the user's choice.
    menu_lines = (
        "CEM Hyperparameter Tuning Suite",
        "Choose an option:",
        "1. Run hyperparameter search",
        "2. Test CEM variants",
        "3. Quick baseline test",
    )
    for menu_line in menu_lines:
        print(menu_line)

    choice = input("Enter choice (1-3): ").strip()

    if choice == "1":
        best_config = hyperparameter_search()
    elif choice == "2":
        results = test_cem_variants()
    else:
        # Options "3" and anything invalid both run a short 3-seed, 50-step test;
        # they differ only in the config and the labels printed.
        if choice == "3":
            config = dict(
                horizon=8,
                pop_size=150,
                elite_frac=0.1,
                cem_iters=10,
                init_std=0.6,
                std_decay=0.9,
                min_std=0.15,
            )
            label = "Baseline CEM"
            print("Running quick baseline test...")
        else:
            print("Invalid choice. Running quick test...")
            config = dict(horizon=5, pop_size=100, elite_frac=0.15, cem_iters=8)
            label = "Quick CEM test"
        mean, std = evaluate_cem_config(config, num_seeds=3, max_steps=50)
        print(f"{label}: {mean:.2f} ± {std:.2f}")